feat(moe): add bf16 DeepEP-normal MoE path via DeepGEMM grouped GEMM #1111
Tanmo-ai wants to merge 5 commits into
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/1 · P3/0
Blocking Issues: P1
Non-blocking Suggestions: P2
Checklist Violations (5 fail / 104 total): General Principles Checklist, RTP-LLM Checklist
Strengths
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/3 · P3/0
lgtm ready to ci
Non-blocking Suggestions: P2
Checklist Violations (7 fail / 104 total): General Principles Checklist, RTP-LLM Checklist
Strengths
e4564ea to 68b0a44
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/0 · P3/0
lgtm ready to ci
Checklist ✅ (104 items passed)
Strengths
2012875 to 8e0e71d
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/2 · P3/0
lgtm ready to ci
Non-blocking Suggestions: P2
Checklist Violations (2 fail / 56 total): General Principles Checklist
Strengths
8e0e71d to 0a39628
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/3 · P3/0
lgtm ready to ci
Non-blocking Suggestions: P2
Checklist Violations (5 fail / 104 total): General Principles Checklist, RTP-LLM Checklist
Strengths
0a39628 to e284541
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/1 · P3/0
Blocking Issues: P1
Non-blocking Suggestions: P2
Checklist Violations (4 fail / 60 total): General Principles Checklist, RTP-LLM Checklist
Strengths
e284541 to cc2025a
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/1 · P3/0
Blocking Issues: P1
Non-blocking Suggestions: P2
Checklist Violations (3 fail / 104 total): General Principles Checklist, RTP-LLM Checklist
Strengths
cc2025a to 9ecf9d8
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/1 · P2/0 · P3/0
Blocking Issues: P1
Checklist Violations (5 fail / 104 total): General Principles Checklist, RTP-LLM Checklist
Strengths
9ecf9d8 to 2658cff
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/0 · P3/0
lgtm ready to ci
Checklist ✅ (56 items passed)
Strengths
internal source has been updated, please review the changes!
Replace runtime signature introspection with explicit validation: bf16 grouped GEMM only supports compiled_dims='nk', reject others with ValueError. Matches PR alibaba#1111 approach. Remove _has_param helper and all inspect.signature usage — eliminates unintrospectable callable, **kwargs, and positional-vs-keyword edge cases.
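The explicit validation described in this commit might look roughly like the sketch below. The function name `check_compiled_dims` is hypothetical; only the `compiled_dims='nk'` constraint and the `ValueError` come from the commit message.

```python
def check_compiled_dims(compiled_dims: str = "nk") -> None:
    """Hypothetical validator: accept only the 'nk' setting, reject anything else."""
    if compiled_dims != "nk":
        raise ValueError(
            f"bf16 grouped GEMM only supports compiled_dims='nk', got {compiled_dims!r}"
        )


check_compiled_dims()        # default value passes
check_compiled_dims("nk")    # explicit default passes
```

Compared with inspecting the callee's signature at runtime, a plain equality check has no edge cases around `**kwargs` or unintrospectable callables.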
2658cff to 81d0d35
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/3 · P3/1
lgtm ready to ci
Non-blocking Suggestions: P2, P3
Checklist Violations (2 fail / 92 total): RTP-LLM Checklist, Python Static-First Checklist
Strengths
c9bfddc to 92943ce
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/5 · P3/1
lgtm ready to ci
Non-blocking Suggestions: P2, P3
Checklist Violations (3 fail / 92 total): RTP-LLM Checklist, Python Static-First Checklist
Strengths
CI dispatcher could not find a native This can happen if the PR was opened before the CI architecture change, or if the original run was deleted. To fix: push any commit (even empty: |
Add an opt-in expert path for unquantized (bf16) MoE under DeepEP normal mode that uses DeepGEMM grouped GEMM instead of the Triton fused_moe_kernel.

- DeepGemmBf16HybridExecutor: runtime-dispatches between a masked 3D layout (small token count / decode) and a contiguous flat layout (large token count / prefill) for better memory utilization.
- ep_scatter_bf16 / ep_scatter_v2_bf16: bf16 variants of the existing fp8 scatter kernels (flat -> contiguous, flat -> 3D masked).
- CudaNoQuantDpNormalDeepGemmStrategy: opt-in only, selected via --moe_strategy no_quant_dp_normal_deepgemm, gated on bf16 + has_deep_gemm + SM>=9 + no CUDA graph. It is NOT part of "auto" selection, so the default MoE path on existing CUDA deployments is unchanged.

deepgemm_wrapper.py changes are backward-compatible and do not affect existing fp8/bf16 callers:

- has_deep_gemm() re-checks until the first successful import (then caches True) instead of caching the first result; for normal processes where deep_gemm is importable at import time it returns True on the first call exactly as before. Needed for spawned subprocesses whose sys.path is set up after module import.
- Symbol resolution is deferred from import-time to first use (_ensure_initialized); functionally identical, only lazy.
- bf16 grouped-GEMM legacy fallback names corrected to the real deep_gemm symbols (gemm_bf16_bf16_bf16_nt*). resolve_symbol() tries the standard name first, so existing resolution is unchanged; this only makes the previously dormant bf16 path resolvable.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
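To illustrate why a contiguous flat layout can use memory better at large token counts, here is a minimal sketch of the buffer-size arithmetic. It assumes the masked layout pads every expert to the largest per-expert count and the contiguous layout pads each expert only to its own aligned count. The helper names, the alignment of 128, and the dispatch threshold are all hypothetical and not taken from the PR.

```python
from typing import List


def align_up(x: int, alignment: int) -> int:
    """Round x up to the next multiple of alignment."""
    return (x + alignment - 1) // alignment * alignment


def masked_rows(counts: List[int], alignment: int) -> int:
    # 3D masked layout (assumed): every expert gets the same padded capacity.
    return len(counts) * align_up(max(counts, default=0), alignment)


def contiguous_rows(counts: List[int], alignment: int) -> int:
    # Flat layout (assumed): each expert gets only its own aligned capacity.
    return sum(align_up(c, alignment) for c in counts)


def choose_layout(token_num: int, threshold: int) -> str:
    # Hypothetical dispatch rule: small batches go masked, large ones contiguous.
    return "masked" if token_num <= threshold else "contiguous"


counts = [300, 5, 0, 90]  # tokens routed to each local expert (made-up numbers)
print(masked_rows(counts, 128), contiguous_rows(counts, 128))  # 1536 640
```

With skewed routing, the masked layout pays for the busiest expert on every expert, while the flat layout pays only for what each expert receives.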
…r e2e

test_ep_scatter_bf16: the scatter kernels assign output slots with a non-deterministic tl.atomic_add, so row order within an expert is not fixed. Rewrite all checks to be order-independent by following output_index (the authoritative token->slot map the gather stage uses) instead of assuming a token-sequential layout. This replaces the previously order-sensitive torch.equal comparisons that could spuriously fail. Also:
- fix the roundtrip tests to use hidden_size % 512 == 0 (ep_gather BLOCK_D=512);
- size the contiguous stress test's per-expert capacity from the real routing histogram (bincount, aligned) instead of a fixed count, matching the executor's allocation and avoiding under-allocation.

deepgemm_bf16_hybrid_executor: new end-to-end test for the bf16 DeepEP-normal hybrid executor (scatter -> grouped GEMM -> silu_and_mul -> grouped GEMM -> gather with router weight), covering both the masked (small token count) and contiguous (large token count) runtime paths against a plain-torch reference. Tagged open_skip + H20 (requires deep_gemm + SM>=9).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
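The order-independent checking idea can be sketched in plain Python. This is a toy model, not the Triton kernel: it assumes top-1 routing, uses a shuffled visiting order to stand in for the non-deterministic atomic_add, and the `scatter` / `check_scatter` helpers are hypothetical.

```python
import random
from typing import List, Optional, Tuple

Slot = Tuple[int, int]  # (expert id, row within that expert)


def scatter(tokens: List[str], expert_ids: List[int], num_experts: int,
            capacity: int, rng: random.Random):
    """Toy scatter: rows within an expert are handed out in arbitrary order."""
    out: List[List[Optional[str]]] = [[None] * capacity for _ in range(num_experts)]
    output_index: List[Optional[Slot]] = [None] * len(tokens)
    order = list(range(len(tokens)))
    rng.shuffle(order)  # stands in for the racing atomic_add
    next_slot = [0] * num_experts
    for t in order:
        e = expert_ids[t]
        slot = next_slot[e]
        next_slot[e] += 1
        out[e][slot] = tokens[t]
        output_index[t] = (e, slot)
    return out, output_index


def check_scatter(tokens, expert_ids, out, output_index) -> bool:
    """Follow output_index per token instead of assuming a sequential layout."""
    seen = set()
    for t, (e, slot) in enumerate(output_index):
        assert e == expert_ids[t], "token landed in the wrong expert"
        assert out[e][slot] == tokens[t], "slot does not hold this token"
        assert (e, slot) not in seen, "two tokens share one slot"
        seen.add((e, slot))
    return True


tokens = ["a", "b", "c", "d", "e"]
expert_ids = [1, 0, 1, 1, 0]
out, idx = scatter(tokens, expert_ids, num_experts=2, capacity=4, rng=random.Random(0))
assert check_scatter(tokens, expert_ids, out, idx)
```

A direct equality comparison against a token-sequential expected buffer would pass or fail depending on the shuffle, which is the spurious failure mode the commit describes.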
…d bf16 init

Review fixes on the bf16 DeepEP-Normal deepgemm MoE path. All changes are confined to the opt-in no_quant_dp_normal_deepgemm path, the bf16 deep_gemm wrappers, and tests; the fp8 path and other default/cross-arch paths are unaffected.

- DeepGemmBf16HybridExecutor.execute: handle empty rank before dispatch. DeepEP small-batch / skewed routing can leave a rank with token_num == 0; that would otherwise enter the masked path with alignment == 0 and launch 0-grid Triton scatter / 0-size DeepGEMM. Return an empty same-shape [0, K] bf16 output.
- DeepGemmBf16HybridExecutor (contiguous path): build the per-expert token-count tensor with .to(device=hidden_states.device, non_blocking=True) instead of .cuda() so it honors the hidden-states device invariant.
- deepgemm_wrapper: decouple bf16 symbol resolution from the fp8 path. Previously _ensure_initialized() resolved fp8 AND bf16 symbols together, so an older deep_gemm build missing the bf16 symbols would raise from _ensure_initialized() and break the fp8 wrappers. Now _ensure_initialized() resolves only _FP8_SYMBOLS (raises if missing, since fp8 is core), while _ensure_bf16_initialized() resolves _BF16_SYMBOLS independently and tolerantly (missing -> impls stay None, never propagate). bf16 wrappers call _ensure_bf16_initialized(); fp8 wrappers keep _ensure_initialized(). has_deep_gemm_bf16_grouped() reports False (never raises) when the bf16 symbols are unavailable.
- deepgemm_wrapper bf16 grouped wrappers: reject a non-default compiled_dims explicitly (NotImplementedError) instead of silently ignoring it; the wrapper does not forward compiled_dims (forwarding perturbs bf16 numerics on this shared path). No current caller passes a non-"nk" value.
- CudaNoQuantDpNormalDeepGemmStrategy: fail fast at selection via has_deep_gemm_bf16_grouped(), and gate on the explicit opt-in moe_strategy FIRST with a short-circuit return (ConditionChecker does not stop at the first failed check, so this keeps the probe from running for non-opt-in / "auto" configs).
- Tests: empty-rank (token_num==0) executor cases (ep1 + ep2); ep_size>1 executor coverage (rank 0/1, _to_local_expert_ids mapping + masking) vs a per-rank torch reference; strategy selection pos/neg; has_deep_gemm_bf16_grouped no-raise and _ensure_bf16_initialized tolerance; executor test skip uses has_deep_gemm_bf16_grouped() to match the gating.

The ep_kernels contiguous padding-row m_indices contract is left to the feature kernel owner (padding output is discarded by the gather; a real fix needs a kernel signature change + a deep_gemm -1 skip contract).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
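The decoupling of required (fp8) and optional (bf16) symbol resolution can be sketched as follows. This is an illustrative model only: the `LazySymbols` class, the `resolve_symbol` signature shown here, and the symbol names on the fake module are assumptions, not the wrapper's actual code.

```python
from types import SimpleNamespace
from typing import Callable, Dict, Optional, Sequence


def resolve_symbol(module: object, names: Sequence[str]) -> Optional[Callable]:
    """Return the first attribute in `names` found on `module`, else None."""
    for name in names:
        fn = getattr(module, name, None)
        if fn is not None:
            return fn
    return None


class LazySymbols:
    """Resolve a symbol table on first use; raise only if the group is required."""

    def __init__(self, module: object, table: Dict[str, Sequence[str]], required: bool):
        self._module = module
        self._table = table
        self._required = required
        self._resolved: Optional[Dict[str, Optional[Callable]]] = None

    def ensure(self) -> Dict[str, Optional[Callable]]:
        if self._resolved is None:
            resolved = {k: resolve_symbol(self._module, v) for k, v in self._table.items()}
            missing = [k for k, fn in resolved.items() if fn is None]
            if missing and self._required:
                raise ImportError(f"missing required symbols: {missing}")
            self._resolved = resolved  # optional groups keep None entries
        return self._resolved

    def available(self) -> bool:
        # Never raises for an optional group; reports False when anything is missing.
        return all(fn is not None for fn in self.ensure().values())


# A fake "older build" that ships only an fp8 entry point (names are made up).
old_build = SimpleNamespace(fp8_gemm=lambda: "fp8 ok")
fp8 = LazySymbols(old_build, {"gemm": ["fp8_gemm_v2", "fp8_gemm"]}, required=True)
bf16 = LazySymbols(old_build, {"grouped": ["bf16_grouped_gemm"]}, required=False)

print(fp8.ensure()["gemm"]())  # fp8 ok
print(bf16.available())        # False
```

Because the two groups are resolved separately, a build without the optional symbols leaves the required group fully usable.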
92943ce to 2594c21
AI Code Review - PR #1111
Status: LGTM
Summary: P0/0 · P1/0 · P2/5 · P3/3
lgtm ready to ci
Non-blocking Suggestions: P2, P3
Checklist Violations (4 fail / 92 total): General Principles Checklist, RTP-LLM Checklist, Python Static-First Checklist
Strengths
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/0 · P2/1 · P3/2
Non-blocking Suggestions: P2, P3
Checklist Violations (3 fail / 56 total): General Principles Checklist, RTP-LLM Checklist
Strengths
03294ae to 363f9c3
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/0 · P2/1 · P3/2
Non-blocking Suggestions: P2, P3
Checklist Violations (2 fail / 56 total): General Principles Checklist, RTP-LLM Checklist
Strengths
12fdee5 to cfe0fab
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/0 · P2/2 · P3/5
Non-blocking Suggestions: P2, P3
Avoid silently treating an upstream import crash as "bf16 unavailable".
Checklist Violations (6 fail / 56 total): General Principles Checklist, RTP-LLM Checklist, Python Static-First Checklist
Strengths
…nearAttn decode kernels
Root cause: the GatedDeltaNet decode kernels silently corrupt data or crash when
a multi-turn conversation accumulates enough tokens that the sequence fills all
allocated KV-cache blocks. Two independent block_map offsets can go out of bounds:

  read path:  read_block_offset  = (sequence_length - 1) // SEQ_SIZE_PER_BLOCK
  write path: write_block_offset = sequence_length // SEQ_SIZE_PER_BLOCK

When sequence_length reaches block_map.size(1) * SEQ_SIZE_PER_BLOCK (+1 for the
write offset, +2 for the read offset), the offset reaches block_map.size(1),
one past the end of the block_map row. The OOB block_map read yields a garbage
block id; every downstream state load/store then computes an out-of-bounds
address, corrupting the KV cache or faulting depending on the platform (some
hardware evaluates load addresses even for masked-off lanes, so a masked load
does not protect against this).

Affected kernels:
  - _causal_conv1d_update_kernel (causal_conv1d.py)
  - fused_recurrent_gated_delta_rule_fwd_kernel (fused_recurrent.py)

Fix:
  - causal_conv1d.py, read path: clamp read_block_offset to the last allocated
    block (stride_block_map - 1) before the block_map load.
  - causal_conv1d.py, write path: replace the masked tl.load with an explicit
    `if write_block_offset < stride_block_map:` branch that fully skips address
    evaluation on OOB.
  - fused_recurrent.py, write path: wrap the write-state block in
    `if write_block_offset < max_block_size:` (the read path already guards via
    `if read_block_id <= 0: return`).

Test: TestCausalConv1dMaxSeqLenBoundary covers both boundaries:
  - write offset OOB at sequence_length == MAX_SEQ_LEN + 1
  - read offset OOB at sequence_length == MAX_SEQ_LEN + 2 (garbage block id fed
    into the conv_state load)
Both verify no crash and finite output; the last in-bounds step stays correct.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
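The two offset formulas and the guards from this commit message can be modelled in plain Python to see where each offset first leaves the block_map row. This is a host-side sketch, not the Triton kernels; the helper names are hypothetical, and how sequence_length is counted relative to the current token is an assumption, so absolute lengths may differ from the kernels by a constant.

```python
from typing import Optional, Tuple


def block_offsets(sequence_length: int, seq_size_per_block: int) -> Tuple[int, int]:
    """Read and write block offsets, using the formulas from the commit message."""
    read = (sequence_length - 1) // seq_size_per_block
    write = sequence_length // seq_size_per_block
    return read, write


def guarded_offsets(sequence_length: int, seq_size_per_block: int,
                    num_blocks: int) -> Tuple[int, Optional[int]]:
    """Clamp the read offset; return None for the write offset when it is OOB."""
    read, write = block_offsets(sequence_length, seq_size_per_block)
    read = min(read, num_blocks - 1)                    # clamp to last allocated block
    return read, (write if write < num_blocks else None)  # skip the write entirely


def first_oob(kind: str, seq_size_per_block: int, num_blocks: int) -> int:
    """Smallest sequence_length at which the chosen offset reaches num_blocks."""
    n = 1
    while True:
        read, write = block_offsets(n, seq_size_per_block)
        if (read if kind == "read" else write) >= num_blocks:
            return n
        n += 1


# With these formulas the write offset leaves the row one step before the read offset.
assert first_oob("read", 4, 3) == first_oob("write", 4, 3) + 1
```

The clamp keeps the read index valid, while the write is skipped rather than masked, mirroring the commit's point that a masked load may still evaluate the address.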
cfe0fab to 4bd6f1f
…aph capture 32k/64k seq_len with DP=8 and 8 capture batch sizes requires 40-50 min for CUDA Graph capture on PPU. The previous 1600s default (and 3000s intermediate fix) was insufficient.
AI Code Review - PR #1111
Status: BLOCKING
Summary: P0/0 · P1/0 · P2/3 · P3/3
Non-blocking Suggestions: P2, P3
Checklist Violations (10 fail / 56 total): General Principles Checklist
Strengths
Summary
Add an opt-in expert path for unquantized (bf16) MoE under DeepEP normal mode
that uses DeepGEMM grouped GEMM instead of the Triton fused_moe_kernel.
Changes
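- DeepGemmBf16HybridExecutor: runtime-dispatches between a masked 3D layout (small token count / decode) and a contiguous flat layout (large token count / prefill) for better memory utilization.
- ep_scatter_bf16 / ep_scatter_v2_bf16: bf16 variants of the existing fp8 scatter kernels (flat → contiguous, flat → 3D masked).
- CudaNoQuantDpNormalDeepGemmStrategy: opt-in only, selected via --moe_strategy no_quant_dp_normal_deepgemm, gated on bf16 + has_deep_gemm + SM≥9 + no CUDA graph. It is not part of "auto" selection, so the default MoE path on existing deployments is unchanged.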
deepgemm_wrapper.py changes (backward-compatible)
These changes enable the bf16 grouped-GEMM path and do not affect existing
fp8/bf16 callers:
- bf16 grouped-GEMM legacy fallback names are corrected to the real deep_gemm symbols (gemm_bf16_bf16_bf16_nt*). resolve_symbol() tries the standard name first and only falls back, so existing resolution is unchanged; this only makes the previously dormant bf16 path resolvable. The stale compiled_dims argument is dropped from the contiguous/masked bf16 calls to match the actual deep_gemm signature.
- Symbol resolution is deferred to first use (_ensure_initialized) instead of at import time. Functionally identical, only lazy: the same symbols are resolved, just on the first actual GEMM call.
- has_deep_gemm() re-checks until the first successful import (then caches True) instead of caching the first result. For normal processes where deep_gemm is importable at import time it returns True on the first call exactly as before; this only adds resilience when the package becomes importable slightly later, and does not change existing behavior.
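The "re-check until the first success, then cache True" behaviour can be sketched as a small probe factory. This is a minimal illustration under stated assumptions: `make_probe` is a hypothetical helper, and the real has_deep_gemm() may be structured differently. The sketch catches only ImportError, so other failures still surface.

```python
import importlib
from typing import Callable


def make_probe(try_import: Callable[[], object]) -> Callable[[], bool]:
    """Build an availability probe that caches success but retries after failure."""
    ok = False

    def probe() -> bool:
        nonlocal ok
        if ok:
            return True       # cached: no further import attempts
        try:
            try_import()
        except ImportError:
            return False      # not cached: the next call tries again
        ok = True
        return True

    return probe


# Usage with a module that is always importable: True on the first call.
has_json = make_probe(lambda: importlib.import_module("json"))
print(has_json())  # True
```

Caching only the positive result is what lets a process whose sys.path is fixed up after module import eventually see the package, without changing the answer for processes where it was importable from the start.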
Testing
test_ep_scatter_bf16.py covers bf16 scatter kernel correctness.